import argparse
import time
from pathlib import Path

import cv2
import numpy
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from yolov5.models.experimental import attempt_load
from yolov5.utils.datasets import LoadStreams, LoadImages, letterbox
from yolov5.utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from yolov5.utils.plots import plot_one_box
from yolov5.utils.torch_utils import select_device, load_classifier, time_synchronized


def detect_son(img, img0, model, device, imgsz, half):
    """Run the detector on one preprocessed frame and draw the detections.

    Reads the module-level ``opt`` (parsed in ``__main__``) for ``augment``,
    ``conf_thres``, ``iou_thres``, ``classes`` and ``agnostic_nms``.

    Args:
        img: Model input as prepared by ``detect()``: the letterboxed frame,
            channel-flipped, transposed to channels-first and made contiguous.
            Assumed to be a uint8 numpy array -- TODO confirm for other callers.
        img0: The original frame. Boxes are rescaled to its shape and drawn
            onto it.
        model: Loaded YOLOv5 model; may be wrapped (exposing ``.module``).
        device: ``torch.device`` the model lives on.
        imgsz: Checked inference size. Kept for interface compatibility; it
            was only used by the per-frame warm-up pass that has been removed.
        half: Feed the model FP16 input when True, FP32 otherwise.

    Returns:
        The annotated frame. When at least one object is detected, the frame
        is also written to ``./result_imgs/<unix-timestamp>.png``.
    """
    im0s = img0
    names = model.module.names if hasattr(model, 'module') else model.names

    # Seeded generator: every call produces the same palette, so a given class
    # keeps the same colour from frame to frame (unseeded colours flickered).
    rng = random.RandomState(0)
    colors = [[int(rng.randint(0, 255)) for _ in range(3)] for _ in names]

    # No dummy warm-up forward pass here: running one on every frame doubled
    # the GPU work per frame, and the first real call warms the model up anyway.

    tensor = torch.from_numpy(img).to(device)
    tensor = tensor.half() if half else tensor.float()  # uint8 to fp16/32
    tensor /= 255.0  # 0 - 255 to 0.0 - 1.0
    if tensor.ndimension() == 3:
        tensor = tensor.unsqueeze(0)  # add batch dimension

    pred = model(tensor, augment=opt.augment)[0]
    pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                               classes=opt.classes, agnostic=opt.agnostic_nms)

    for det in pred:  # detections per image
        if not len(det):
            continue
        # Rescale boxes from the letterboxed input size to the original frame.
        det[:, :4] = scale_coords(tensor.shape[2:], det[:, :4], im0s.shape).round()

        for *xyxy, conf, cls in reversed(det):
            label = f'{names[int(cls)]} {conf:.2f}'
            drawn = plot_one_box(xyxy, im0s, label=label, color=colors[int(cls)], line_thickness=3)
            # Upstream plot_one_box draws in place and returns None; only adopt
            # the return value when a patched version hands back the image.
            if drawn is not None:
                im0s = drawn

        # imwrite silently returns False if the directory is missing.
        Path('./result_imgs').mkdir(parents=True, exist_ok=True)
        cv2.imwrite('./result_imgs/' + str(time.time()) + '.png', im0s)
    return im0s



def detect(camera_num):
    """Run live detection on a camera or video stream, displaying and recording it.

    Annotated frames are shown in an OpenCV window and written to
    ``./result.mp4`` until ``q`` is pressed or the stream yields no more
    frames. Reads ``weights``, ``img_size`` and ``device`` from the
    module-level ``opt``.

    Args:
        camera_num: Source passed to ``cv2.VideoCapture``: a camera index
            (``0`` is the default camera) or a video path. A numeric string
            such as ``'0'`` (what argparse yields for ``--source 0``) is
            converted to an int so it opens a camera rather than a file
            literally named ``0``.
    """
    if isinstance(camera_num, str) and camera_num.isnumeric():
        camera_num = int(camera_num)

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(opt.weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(opt.img_size, s=stride)  # check img_size
    if half:
        model.half()  # to FP16

    cudnn.benchmark = True  # set True to speed up constant image size inference

    cap = cv2.VideoCapture(camera_num)  # camera index, 0 by default
    display_name = 'Display: Detection'
    cv2.namedWindow(display_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
    cv2.resizeWindow(display_name, 960, 720)
    font_color = (0, 0, 0)
    fps = 30
    save_path = './result.mp4'
    # Opened lazily from the first frame's real size: VideoWriter silently
    # drops frames whose size differs from the one it was opened with.
    vid_writer = None

    try:
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:  # camera failed or stream ended
                break
            frame_disp = frame.copy()

            # Letterbox to the checked inference size, BGR->RGB, HWC->CHW.
            img = letterbox(frame_disp, imgsz, stride=stride)[0]
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = numpy.ascontiguousarray(img)

            frame_disp = detect_son(img, frame_disp, model, device, imgsz, half)
            cv2.putText(frame_disp, 'Press q to quit', (20, 85), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
                        font_color, 1)

            cv2.imshow(display_name, frame_disp)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            if vid_writer is None:
                h, w = frame_disp.shape[:2]
                vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            vid_writer.write(frame_disp)
    finally:
        # Release even on errors so the mp4 is finalized and the camera freed.
        if vid_writer is not None:
            vid_writer.release()
        cap.release()
        cv2.destroyAllWindows()


















if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='weights/last.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default=0,
                        help='camera index (0 for webcam 0) or video path, e.g. ./smoke_test.mp4')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.65, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    opt = parser.parse_args()
    print(opt)
    check_requirements(exclude=('pycocotools', 'thop'))

    # argparse turns "--source 0" into the string '0'; cv2.VideoCapture needs
    # an int to open a camera, otherwise it looks for a file named "0".
    source = int(opt.source) if str(opt.source).isnumeric() else opt.source

    with torch.no_grad():  # inference only; no autograd bookkeeping
        detect(camera_num=source)
